[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
Conversation
Greptile SummaryAdds fused grouped-tensor FP8 block-scaling quantize (1D and 2D) and dequantize kernels for Hopper (SM90-SM99), plus full PyTorch integration into
Confidence Score: 4/5The new CUDA kernels and C++ dispatchers are well-constructed; the two defects are both in the Python test helper, not in the production path. The make_quantizer helper in test_grouped_tensor.py sets force_pow_2_scales=True for the fp8_blockwise case, while Float8BlockQuantizer::create_grouped_tensor now explicitly rejects that flag — every test parameterized with _quantization_params[fp8_blockwise] that calls grouped_tensor.quantize() will raise NVTE_ERROR on Hopper before any kernel fires. Separately, the skip-condition for that param entry uses fp8_block_scaling_available rather than fp8_block_scaling_grouped_available, allowing the tests to run on Blackwell where the SM90-SM99 guard would immediately reject them. tests/pytorch/test_grouped_tensor.py — the make_quantizer helper and the _quantization_params skip condition for fp8_blockwise both need correction. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Python group_quantize] --> B{quantizer type?}
B -->|Float8BlockwiseQuantizer| C[FP8_BLOCKWISE_GROUPED_QUANTIZE]
B -->|MXFP8| D[MXFP8 path]
B -->|NVFP4| E[NVFP4 path]
C --> G[create_grouped_tensor - NVTE_CHECK force_pow_2_scales==false]
G --> H{scaling_mode?}
H -->|BLOCK_SCALING_1D| I[group_quantize_blockwise_1d]
H -->|BLOCK_SCALING_2D| J[group_quantize_blockwise_2d]
I --> K{RW-only no dbias?}
K -->|yes| L[group_block_scaled_1d_rw_kernel - no smem]
K -->|no| M[group_block_scaled_1d_tma_kernel - TMA CW/BOTH]
J --> N[group_block_scaled_2d_tma_kernel - TMA pass1 amax pass2 quant]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[Python group_quantize] --> B{quantizer type?}
B -->|Float8BlockwiseQuantizer| C[FP8_BLOCKWISE_GROUPED_QUANTIZE]
B -->|MXFP8| D[MXFP8 path]
B -->|NVFP4| E[NVFP4 path]
C --> G[create_grouped_tensor - NVTE_CHECK force_pow_2_scales==false]
G --> H{scaling_mode?}
H -->|BLOCK_SCALING_1D| I[group_quantize_blockwise_1d]
H -->|BLOCK_SCALING_2D| J[group_quantize_blockwise_2d]
I --> K{RW-only no dbias?}
K -->|yes| L[group_block_scaled_1d_rw_kernel - no smem]
K -->|no| M[group_block_scaled_1d_tma_kernel - TMA CW/BOTH]
J --> N[group_block_scaled_2d_tma_kernel - TMA pass1 amax pass2 quant]
|
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT, align_smem_ptr_per_TMA_requirements, get_current_tensor_id, subwarp_reduce_max_broadcast) in place of local equivalents. - Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels. - Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM (matches MXFP8 grouped quantize behavior). - Fix Hopper SM range wording in 1D dispatcher. - Extend cpp tests to cover with_gemm_swizzled_scales path. Signed-off-by: Alp Dener <adener@nvidia.com>
| // num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4) | ||
| __device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, | ||
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; |
There was a problem hiding this comment.
I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling
There was a problem hiding this comment.
Grouped GEMM doesn't read FP8 block-scales in swizzled format. It requires a compact per-expert format instead, so I stripped out all the swizzle changes out of this PR.
dc998bd to
6a25307
Compare
1c17b49 to
20c98e5
Compare
|
/te-ci cpp pytorch |
20c98e5 to
ab816b5
Compare
ab816b5 to
7b115f5
Compare
|
/te-ci core pytorch |
vthumbe1503
left a comment
There was a problem hiding this comment.
LGTM! We should probably move group dequantize to a different PR or add a test for it in this PR.
| } | ||
| case NVTE_BLOCK_SCALING_1D: | ||
| case NVTE_BLOCK_SCALING_2D: { | ||
| fp8_blockwise::group_dequantize(&input, output, stream); |
There was a problem hiding this comment.
We would need group dequantize tests as well for fp8 blockwise.
Alternatively it can be seperated out into a different PR, since group dequant is not a priority for GroupedLinear integration
There was a problem hiding this comment.
PR already includes test_grouped_tensor.py::test_group_dequantize_fp8_blockwise so the dequantize is tested, but we're missing test_group_dequantize_cudagraph_capturable. I'll add it for parity with grouped MXFP8.
There was a problem hiding this comment.
Do we have CPP test in common though, considering JAX might also need it?
There was a problem hiding this comment.
We do now. Just pushed the commit with the new tests, both C++ and PyTorch, passing on H100.
No tests on the JAX side. The PR doesn't include any JAX changes. I punted that to a separate PR later this week because JAX needs a lot more changes to integrate FP8BS. Not quite as drop-in ready as PyTorch GroupedTensor.
| NVTE_CHECK(info.tensor_offsets_d != nullptr, | ||
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = DIVUP(info.R_total, static_cast<size_t>(kTileDim)); |
There was a problem hiding this comment.
We might want to have static persistent kernel similar to mxfp8 in case of cuda graphs in future.
With cuda graphs total_rows can be larger than the total sum of first dims. And so we would be overlaunching thread blocks.
Not a blocker for the current PR. Persistent kernel based optimization can be a future PR if necessary.
There was a problem hiding this comment.
Sure, I'll turn that into a separate PR. Would be a self-contained change with no framework integration impact so it should be something very minimal to review and merge later.
7b115f5 to
b00947e
Compare
…ling
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization
directions. A single CUDA kernel launch walks 128x128 tiles across every tensor
in the group, with each CTA decoding its owning tensor from the device-side
GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all
tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape
representations.
Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:
- group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads
global memory directly into vec-16 registers; bypasses TMA since the
shared-memory roundtrip and ptx::mbarrier do not buy anything without
re-use in the CW path.
- group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch. TMA bulk-load
fills shared memory input cache. BOTH runs an RW pass (8 threads/row,
vec-16 read from shared memory) then a CW pass; CW-only skips the RW
pass. The CW pass uses 4 t/col with 32-row reg_data and two column passes
in the BOTH instantiation (keeps the per-thread register footprint under
the sm_90 3-CTAs/SM threshold) and 2 t/col with 64-row reg_data in the
CW-only instantiation (avoids doubling the smem-load bank-conflict
footprint that 4 t/col would introduce).
- group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch. TMA
bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread
in registers while computing the per-tile scalar amax. Pass 2 quantizes
from registers, emits row-wise output, stages column-wise output to the
shared memory transpose staging buffer, then drains smem_T to global
memory.
Per-expert scale offsets:
- 1D RW: closed-form O(1) for both SAME_BOTH_DIMS and VARYING_FIRST_DIM
(each M_i is a multiple of kTileDim=128, hence of kScaleColAlign=4, so
DIVUP_TO_MULTIPLE collapses and the prefix sum reduces to a single
tensor_offsets_ptr[tensor_id]/K load).
- 2D CW: closed-form O(1) for SAME_BOTH_DIMS; CTA-cooperative warp-shuffle
prefix sum for VARYING_FIRST_DIM (non-linear DIVUP_TO_MULTIPLE on
blocks_y_t prevents a closed form). The cooperative reduction uses the
existing warp_allreduce_sum helper from common/utils.cuh.
Dequantize and bias-gradient (bgrad):
- group_dequantize_fp8_blockwise.cuh: kernels for all four modes
(1D/2D x rowwise/columnwise), inverting the per-expert layouts the
quantize kernels write.
- bgrad_group_quantize accepts Float8Block quantizers and computes dbias
per-tile column-partial in-kernel (mirroring MXFP8); reduced per expert
via the existing common::grouped_reduce_dbias.
Scale constraints: the fused grouped FP8BS path supports only unconstrained
FP32 scales (Float8BlockQuantizer::create_grouped_tensor rejects
force_pow_2_scales=True). Power-of-2 scales remain available on the
non-grouped/unfused split-quantize path used for Blackwell MXFP8 emulation.
Tests: existing parametrized grouped quantize / dequantize / bgrad tests
in test_grouped_tensor.py cover MXFP8, NVFP4, FP8 current scaling and the
newly-added FP8 block scaling recipe. tests/cpp/operator/
test_cast_float8blockwise_grouped.cu adds 72 C++ unit-test cases over
uniform/jagged shapes, all four (BD x direction) modes, K in {128, 256,
512}, and CUDA-graph capture coverage.
Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt
grouped GEMM supports FP8 block-scaling only on Hopper).
JAX integration is intentionally left out of scope and deferred to a
follow-up PR.
Resolves NVIDIA#2525
Signed-off-by: Alp Dener <adener@nvidia.com>
b00947e to
82ea6a3
Compare
Description
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports
SAME_BOTH_DIMS(all tensors identical) andVARYING_FIRST_DIM(constant K, varying R) shape representations.Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:group_block_scaled_1d_rw_kernel— RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip andptx::mbarrierdoes not buy anything without re-use in CW path.group_block_scaled_1d_tma_kernel— CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. BOTH runs RW pass first (8 threads/row, vec-16 read from shared memory) then CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.group_block_scaled_2d_tma_kernel— RW-only, CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to shared memory transpose staging buffer, then drains to global memory.Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt grouped GEMM supports FP8 block-scaling only on Hopper).
PR includes PyTorch integration into te.GroupedTensor only.
PyTorch integration into te.GroupedLinear and JAX integration deferred to a follow-up PRs.
Partially resolves #2525
Performance
Benchmark on H200 with a sweep of grouped tensors in (N, M, K) shapes:
Two shape families per config:
SAME_BOTH_DIMS): all experts share the (M, K) shapeVARYING_FIRST_DIM): per-expert M drawn from an imbalanced routing, common KBuckets:
Bucket medians across 3 reps. Speedup is grouped vs the split-quantized fallback that loops over the grouped tensor and quantizes each constituent sequentially. % mono is grouped throughput relative to a single non-grouped FP8 block-scaling quantize on the equivalent monolithic (N·M, K) tensor.
Notes
Known Sub-Optimalities
1D CW load bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)
CU_TENSOR_MAP_SWIZZLE_128Bhas the right pattern but caps FP16/BF16 at 64-elements; does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).1D BOTH reads the shared memory input-cache twice
2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)
No TMA-store
Type of change
Checklist: